import numpy as np
from scipy.optimize import minimize
import pandas as pd
from IPython.display import clear_output
from scipy.optimize import fsolve
import time
import matplotlib.pyplot as plt
import networkx as nx
import copy



def F_a_f(a, f, y, u, Arguments):
    """
    Compute the single component F_{a, f} of the velocity matrix F (formula (2) in the
    main manuscript).

    The result is an inflow term minus an outflow term:
      inflow  = sum_{s, l} [ sum_{r < M} y[s, f] * y[l, r] * T['f-r'][s][l, a]
                             + y[s, f] * u[l, f] * T['f-M'][s][l, a] ]
      outflow = y[a, f] * (n + sum_l u[l, f])
    where T = Arguments['Transition_Matrices'] and the key 'f-M' holds the matrices
    used for interactions with the stubborn agents (the control u).

    Input:
    a         - the first index of F_{a, f} (opinion dimension)
    f         - the second index of F_{a, f} (type dimension)
    y         - the (m, M) matrix of state variables
    u         - the (m, M) matrix of control variables
    Arguments - hyperparameters; reads 'm', 'M', 'n' and 'Transition_Matrices'
                (NOTE(review): 'n' presumably is the total ordinary population — confirm)

    Output:
    the scalar value F_{a, f}
    """
    m, M = Arguments['m'], Arguments['M']
    n = Arguments['n']
    matrices = Arguments['Transition_Matrices']

    # The dictionary keys depend only on (f, r), so resolve them once instead of
    # rebuilding the f-string key inside the innermost loop.
    ordinary_mats = [matrices[f'{f}-{r}'] for r in range(M)]
    stubborn_mats = matrices[f'{f}-{M}']

    # Fraction of stubborn agents that communicate with type-f agents at this moment.
    stubborn_share = u[:, f].sum()

    total = 0
    for s in range(m):
        y_sf = y[s, f]
        for l in range(m):
            # Interactions of (s, f) agents with ordinary agents of every type r.
            for r in range(M):
                total = total + y_sf * y[l, r] * ordinary_mats[r][s][l, a]
            # Interaction of (s, f) agents with stubborn agents holding opinion l.
            total = total + y_sf * u[l, f] * stubborn_mats[s][l, a]

    return total - y[a, f] * (n + stubborn_share)



def F(y, u, Arguments):
    """
    Assemble the full (m, M) velocity matrix F (formula (2) in the main manuscript)
    by evaluating F_a_f for every opinion index a and every type index f.

    Input:
    y         - the matrix of state variables
    u         - the matrix of control variables
    Arguments - hyperparameters; 'm' and 'M' fix the output shape

    Output:
    a float array of shape (m, M) whose [a, f] entry is F_{a, f}
    """
    shape = (Arguments['m'], Arguments['M'])
    velocity_matrix = np.zeros(shape)

    # np.ndindex walks the indices in row-major order: a outer, f inner.
    for a, f in np.ndindex(*shape):
        velocity_matrix[a, f] = F_a_f(a, f, y, u, Arguments)

    return velocity_matrix







def Objective_Function_Adjustment(Dependent_Variable, weights):
    """
    Return the SQUARED Euclidean (Frobenius) distance between Dependent_Variable and
    weights, i.e. sum((Dependent_Variable - weights) ** 2).

    Note that this is the squared L2 norm, not the norm itself; the minimizer is the same.
    """
    residual = Dependent_Variable - weights
    return (residual * residual).sum()



def State_Constraint(y, f, Arguments):
    """
    Equality constraint for agent type f: the result is zero exactly when the total
    fraction of ordinary agents of type f (column f of the state matrix) equals n_types[f].

    Input:
    y         - the state variables flattened to a vector of length m*M (row-major),
                as produced by reshaping an (m, M) matrix
    f         - the index of the type (column of the state matrix)
    Arguments - hyperparameters; reads 'n_types' (length-M array of type populations),
                'm' and 'M'

    Output:
    column sum of type f minus n_types[f]
    """
    populations = Arguments['n_types']
    n_rows, n_cols = Arguments['m'], Arguments['M']

    type_column = y.reshape(n_rows, n_cols)[:, f]
    return type_column.sum() - populations[f]



def _project_onto_scaled_simplex(v, total):
    """
    Return the Euclidean projection of the 1-D float array v onto the set
    {x : x >= 0, sum(x) == total}.

    Uses the standard sort-based algorithm: the projection has the form
    max(v - theta, 0) for a scalar threshold theta chosen so the result sums to total.

    Raises:
    ValueError - if the set is empty (total < 0, or total > 0 with no entries)
    """
    if total < 0:
        raise ValueError(f'Cannot project onto a simplex with negative total {total}')
    if total == 0:
        # The only nonnegative vector with zero sum is the zero vector.
        return np.zeros_like(v)
    if v.size == 0:
        raise ValueError(f'Cannot distribute a positive total {total} over zero entries')

    v_desc = np.sort(v)[::-1]
    cumulative = np.cumsum(v_desc)
    counts = np.arange(1, v.size + 1)
    active = np.flatnonzero(v_desc - (cumulative - total) / counts > 0)
    # For total > 0 index 0 is always active in exact arithmetic; fall back to it
    # if floating-point rounding hides that.
    rho = active[-1] if active.size else 0
    theta = (cumulative[rho] - total) / (rho + 1)
    return np.maximum(v - theta, 0.0)


def Adjust_State_Matrix(y_new, Arguments):
    """
    Correct the matrix of state variables y_new so that the type constraints hold:
    all state variables are nonnegative, and the column sums equal
    n_types[0], ..., n_types[M-1].

    The corrected matrix is the one closest to y_new in the (squared) L2 norm among all
    matrices meeting the type constraints. The constraints separate by column, so each
    column is projected independently and exactly onto a scaled probability simplex.
    This replaces an iterative SLSQP solve, which gave only an approximate answer and
    silently returned its last iterate when it failed.

    Input:
    y_new     - the matrix of state variables (anything reshapeable to (m, M))
    Arguments - hyperparameters; reads 'm', 'M' and 'n_types'

    Output:
    a new float array of shape (m, M); y_new itself is not modified

    Raises:
    ValueError - if some n_types[f] is negative (the constraints are then infeasible)
    """
    m = Arguments['m']
    M = Arguments['M']
    n_types = Arguments['n_types']

    y_matrix = np.asarray(y_new, dtype=float).reshape(m, M)
    adjusted = np.empty_like(y_matrix)

    for f in range(M):
        adjusted[:, f] = _project_onto_scaled_simplex(y_matrix[:, f], n_types[f])

    return adjusted



def RKM_Iteration_Dyn_System(y_previous, u_previous, Arguments):
    """
    Perform one step of the classical 4th-order Runge-Kutta method for the system
    dy/dtau = F(y, u).

    The control matrix is held fixed at u_previous for the whole step. All four stages
    evaluate F with the same control, so u is treated as piecewise constant on the grid.

    Input:
    y_previous - the values of the state matrix at the previous time moment
    u_previous - the values of the control matrix at the previous time moment
    Arguments  - hyperparameters; 'Step' (the grid step size) is read here and the whole
                 dict is passed through to F

    Output:
    y_new      - the values of the state matrix at the next time moment
    """
    Step = Arguments['Step']

    # Classical RK4 stages (Butcher weights 1/6, 1/3, 1/3, 1/6).
    K1 = Step * F(y_previous, u_previous, Arguments)
    K2 = Step * F(y_previous + K1/2, u_previous, Arguments)
    K3 = Step * F(y_previous + K2/2, u_previous, Arguments)
    K4 = Step * F(y_previous + K3, u_previous, Arguments)

    y_new = y_previous + K1/6 + K2/3 + K3/3 + K4/6

    return y_new

    

def RKM_Dyn_System(y_0, u_t, Arguments, iteration_counter, adjust=True):
    """
    Solve the Cauchy problem (3), (4) from the main manuscript with the 4th-order
    Runge-Kutta method, taking one step per interval of Arguments['Grid'].

    Input:
    y_0               - the starting point of the system (see formula (4) from the main manuscript)
    u_t               - the sequence of control matrices (control variable u(tau) from the
                        main manuscript); must have at least len(Grid) - 1 entries, and
                        u_t[i] is used on the step from Grid[i] to Grid[i+1]
    Arguments         - hyperparameters; 'Grid' is read here, and the dict is passed on
                        to the RK step and the adjustment routine
    iteration_counter - the number of the current FBS method iteration (kept for interface
                        compatibility; it is not used in the computation)
    adjust            - if True, after every step the new state is projected so that it
                        meets the type constraints (nonnegative entries, column sums equal
                        to n_1, ..., n_M)

    Output:
    y_t               - list of len(Grid) state matrices, y_t[0] being y_0 (y(tau) in the
                        main manuscript)
    """
    Grid = Arguments['Grid']

    # One Runge-Kutta step per grid interval.
    N_Iter = len(Grid) - 1

    y_t = [y_0]

    for i in range(N_Iter):
        y_new = RKM_Iteration_Dyn_System(y_t[i], u_t[i], Arguments)

        # RK4 does not preserve positivity or the column sums exactly, so
        # optionally restore the type constraints.
        if adjust:
            y_new = Adjust_State_Matrix(y_new, Arguments)

        y_t.append(y_new)

    return y_t





                
                




